/*
 * This file is part of monitor_agent.
 * Copyright (c) 2018. Author: yinjia evoex123@gmail.com
 *
 * This program is free software: you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or (at your
 * option) any later version.  This program is distributed in the hope that it
 * will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser
 * General Public License for more details.  You should have received a copy
 * of the GNU Lesser General Public License along with this program.  If not,
 * see <http://www.gnu.org/licenses/>.
 */

package module

import (
	"bytes"
	"crypto/hmac"
	crand "crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"math/rand"
	"regexp"
	"strconv"

	"golang.org/x/crypto/pbkdf2"
)

// Fixed parameters of the SCRAM-SHA-1 exchange performed by this module.
const (
	// ClientHeader is the channel-binding value sent as "c=" in the
	// client-final-message; it is base64("n,,"), i.e. the GS2 header used
	// in the client-first-message.
	ClientHeader = "biws"

	// ClientPass is the password the client authenticates with. It is the
	// same password used in the RFC 5802 example exchange.
	// NOTE(review): a hard-coded credential; consider making it configurable.
	ClientPass = "pencil"

	// PBKDF2Length is the derived-key length in bytes, equal to the size of
	// a SHA-1 digest.
	PBKDF2Length = 20
)

var (
	// ServerFinalMessage matches a server-final-message and captures the
	// value of its trailing "v=" attribute (the server signature).
	ServerFinalMessage = regexp.MustCompile(`v=([^,]*)$`)
)

// RandStringBytesRmndr returns a string of n ASCII letters (a-z, A-Z),
// each picked with the remainder of a math/rand Int63 draw.
//
// The result is NOT cryptographically random: math/rand is a
// deterministic PRNG. A negative n panics (make with a negative length).
func RandStringBytesRmndr(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	buf := make([]byte, n)
	for idx := 0; idx < len(buf); idx++ {
		pick := rand.Int63() % int64(len(alphabet))
		buf[idx] = alphabet[pick]
	}
	return string(buf)
}

// Client nonce: This is a value that is randomly generated by the client,
// ideally using a cryptographic random generator.
func makeClientNonce() string {
	return RandStringBytesRmndr(10)
}

// clientFirstMessageBare builds the SCRAM client-first-message-bare,
// "n=<username>,r=<client nonce>".
//
// RFC 5802 (section 5.1) requires ',' and '=' in the username to be sent
// as "=2C" and "=3D". Without this, a username containing ',' would inject
// extra attributes into the message and corrupt the attribute list. The
// returned slice is freshly allocated and never aliases cName or cNonce.
func clientFirstMessageBare(cName []byte, cNonce []byte) (out []byte) {
	// Escape '=' first so the '=' introduced by "=2C" is not escaped again.
	name := bytes.Replace(cName, []byte("="), []byte("=3D"), -1)
	name = bytes.Replace(name, []byte(","), []byte("=2C"), -1)

	out = make([]byte, 0, len("n=")+len(name)+len(",r=")+len(cNonce))
	out = append(out, "n="...)
	out = append(out, name...)
	out = append(out, ",r="...)
	out = append(out, cNonce...)
	return out
}

// clientFirstMessage returns the full SCRAM client-first-message: the GS2
// header "n,," followed by the bare message produced by
// clientFirstMessageBare for cName and cNonce.
func clientFirstMessage(cName, cNonce []byte) (out []byte) {
	const gs2Header = "n,,"
	bare := clientFirstMessageBare(cName, cNonce)
	msg := make([]byte, 0, len(gs2Header)+len(bare))
	msg = append(msg, gs2Header...)
	return append(msg, bare...)
}

// getAttribute returns the value of the first "<attribute>=<value>" field
// in the comma-separated SCRAM message, or nil if there is none.
//
// The returned slice aliases message; copy it if message will be modified.
func getAttribute(message []byte, attribute byte) []byte {
	for _, field := range bytes.Split(message, []byte{','}) {
		// Require the "x=" form. A bare one-letter field such as the GS2
		// flag "n" in "n,,n=user" must not match, and slicing field[2:] on
		// a one-byte field would panic.
		if len(field) >= 2 && field[0] == attribute && field[1] == '=' {
			return field[2:]
		}
	}
	return nil
}

func scramSha1FirstMessage(cname string) ([]byte, []byte) {
	fmt.Println("scram sha-1 login")
	//cName := []byte("clientName")
	cName := []byte(cname)
	cNonce := []byte(makeClientNonce())
	cFirstMessage := clientFirstMessage(cName, cNonce)
	fmt.Printf("1.C: %s\n", cFirstMessage)
	cNonce = getAttribute(cFirstMessage, byte('r'))
    // 必须要加，因为服务端netty用的是StringDecoder
	cFirstMessage = append(cFirstMessage, "\n"...)
	return cFirstMessage, cNonce
}

// scramServerFirstRe parses a SCRAM server-first-message of the form
// "r=<client nonce><server nonce>,s=<base64 salt>,i=<iteration count>".
// It is compiled once at package level instead of on every login.
var scramServerFirstRe = regexp.MustCompile(`r=([^,]*),s=([^,]*),i=(.*)$`)

// scramSha1FinalMessage answers a server-first-message with the SCRAM-SHA-1
// client-final-message (including the ",p=" client proof), authenticating
// cname with ClientPass.
//
// Example server-first-message:
//
//	r=oJnNPGsiuz152d4ba7-d324-4228-8a63-78b352851853,s=b174075f-7512-421c-92ab-81cc1fcf9585,i=4096
//
// The pattern is anchored with `$` and `.` does not match '\n', so the
// message must be passed without a trailing newline.
//
// On success it returns four values:
//   - out: the '\n'-terminated client-final-message
//   - salt: the base64-decoded salt
//   - snonce: the server's part of the nonce
//   - iter: the iteration count
//
// If the message does not parse, all results are zero (out is nil). If the
// nonce does not start with cnonce, or the iteration count is not a
// positive integer, out is an empty non-nil slice and the rest are zero.
func scramSha1FinalMessage(serverFisrtMessage []byte, cname string, cnonce []byte) (out []byte, salt string,
	snonce string, iter int) {
	submatch := scramServerFirstRe.FindStringSubmatch(string(serverFisrtMessage))
	if submatch == nil {
		return nil, "", "", 0
	}
	nonce, encodedSalt, iterText := submatch[1], submatch[2], submatch[3]

	// The combined nonce must start with our own nonce, otherwise the reply
	// is not for this exchange. Check the length first: the server controls
	// nonce, and slicing past its end would panic.
	if len(nonce) < len(cnonce) || nonce[:len(cnonce)] != string(cnonce) {
		return []byte(""), "", "", 0
	}
	snonce = nonce[len(cnonce):]

	// An unparsable or non-positive iteration count must fail the exchange
	// rather than silently run PBKDF2 with 0 iterations.
	iter, err := strconv.Atoi(iterText)
	if err != nil || iter <= 0 {
		return []byte(""), "", "", 0
	}

	authMsg := authMessage(cname, cnonce, []byte(snonce), ClientHeader, string(serverFisrtMessage))
	salt = string(fromBase64([]byte(encodedSalt)))

	// RFC 5802:
	//   SaltedPassword  = Hi(Normalize(password), salt, i)
	//   ClientKey       = HMAC(SaltedPassword, "Client Key")
	//   StoredKey       = H(ClientKey)
	//   ClientSignature = HMAC(StoredKey, AuthMessage)
	//   ClientProof     = ClientKey XOR ClientSignature
	saltedPassword := pbkdf2Sum(normalize([]byte(ClientPass)), []byte(salt), iter)
	clientKey := hmacSum(saltedPassword, []byte("Client Key"))
	storedKey := sha1Sum(clientKey)
	clientSignature := hmacSum(storedKey, authMsg)
	clientProof := xor(clientKey, clientSignature)

	out = clientFinalMessageWithoutProof([]byte(ClientHeader), cnonce, []byte(snonce))
	out = append(out, ",p="...)
	out = append(out, toBase64(clientProof)...)
	// Required: the server side (Netty) decodes input with StringDecoder,
	// so each message must end with a newline.
	out = append(out, '\n')
	return out, salt, snonce, iter
}

// authMessage builds the SCRAM AuthMessage, the text signed by both sides.
// It is three parts joined by ",":
//
//	client-first-message-bare + "," + server-first-message + "," +
//	client-final-message-without-proof
func authMessage(cName string, cNonce []byte, sNonce []byte, cHeader string, serverFirstMessage string) (out []byte) {
	parts := [][]byte{
		clientFirstMessageBare([]byte(cName), cNonce),
		[]byte(serverFirstMessage),
		clientFinalMessageWithoutProof([]byte(cHeader), cNonce, sNonce),
	}
	return bytes.Join(parts, []byte{','})
}

// clientFinalMessageWithoutProof builds "c=<cHeader>,r=<cNonce><sNonce>",
// the SCRAM client-final-message without its ",p=" proof.
//
// The nonces are appended into a freshly allocated result. The previous
// version did append(cNonce, sNonce...), which writes into any spare
// capacity of cNonce's backing array. The nonce handed around by this file
// is a sub-slice of the client-first-message buffer (it comes from
// getAttribute), so that append could silently overwrite the caller's data.
func clientFinalMessageWithoutProof(cHeader, cNonce, sNonce []byte) (out []byte) {
	out = make([]byte, 0, len("c=")+len(cHeader)+len(",r=")+len(cNonce)+len(sNonce))
	out = append(out, "c="...)
	out = append(out, cHeader...)
	out = append(out, ",r="...)
	out = append(out, cNonce...)
	out = append(out, sNonce...)
	return out
}

// normalize stands in for the Normalize step of SCRAM (RFC 5802), which
// calls for SASLprep. It currently applies no normalization and returns
// in itself, unchanged and uncopied.
func normalize(in []byte) (out []byte) {
	out = in
	return out
}
// toBase64 encodes src with standard, padded base64 (base64.StdEncoding)
// and returns the encoded bytes.
func toBase64(src []byte) []byte {
	enc := base64.StdEncoding
	dst := make([]byte, enc.EncodedLen(len(src)))
	enc.Encode(dst, src)
	return dst
}

// fromBase64 decodes standard, padded base64 (base64.StdEncoding).
//
// Invalid input is not reported. The result holds whatever bytes were
// decoded before the decoder stopped, which may be none.
func fromBase64(src []byte) []byte {
	decoded, _ := base64.StdEncoding.DecodeString(string(src)) // decode error intentionally dropped; see doc above
	return decoded
}

// pbkdf2Sum derives a PBKDF2-HMAC-SHA1 key of PBKDF2Length bytes from
// password and salt with i iterations. This is the Hi() function of RFC 5802.
func pbkdf2Sum(password, salt []byte, i int) []byte {
	derived := pbkdf2.Key(password, salt, i, PBKDF2Length, sha1.New)
	return derived
}

// hmacSum returns HMAC-SHA1(key, message) as a fresh 20-byte slice.
func hmacSum(key, message []byte) (digest []byte) {
	mac := hmac.New(sha1.New, key)
	_, _ = mac.Write(message) // hash.Hash writes never return an error
	digest = mac.Sum(digest)
	return digest
}

// sha1Sum returns the SHA-1 digest of message as a 20-byte slice.
func sha1Sum(message []byte) []byte {
	sum := sha1.Sum(message)
	return sum[:]
}

// xor returns the byte-wise XOR of a and b in a new slice.
//
// If the lengths differ, a warning with both lengths is printed to stdout,
// and only the first min(len(a), len(b)) bytes are combined.
func xor(a, b []byte) []byte {
	n := len(a)
	if len(b) != n {
		// Report only the lengths. Callers pass key material here (e.g.
		// ClientKey and ClientSignature), which must not end up in logs.
		fmt.Println("Warning: xor lengths are differing...", len(a), len(b))
		if len(b) < n {
			n = len(b)
		}
	}

	out := make([]byte, n)
	for i := range out {
		out[i] = a[i] ^ b[i]
	}
	return out
}

// isValidServer verifies the server's signature from the
// server-final-message, proving that the server also knows the salted
// password (SCRAM mutual authentication, RFC 5802):
//
//	ServerKey       = HMAC(SaltedPassword, "Server Key")
//	ServerSignature = HMAC(ServerKey, AuthMessage)
//
// Parameters:
//   - sSalt is used as raw salt bytes (no base64 decoding here).
//   - serverSignature is compared against a raw HMAC, so it must be the
//     decoded bytes (presumably of the "v=" value; confirm at the caller).
//
// The comparison runs in constant time.
//
// NOTE(review): the previous code additionally SHA-1 hashed ServerKey,
// which no RFC 5802 server does. If the paired server mirrors that extra
// hash, it needs the same fix.
func isValidServer(cName string, cPass []byte, cNonce []byte, sNonce []byte, sSalt string, cHeader string, serverSignature []byte, iterations int,
	serverFirstMessage string) bool {
	authMsg := authMessage(cName, cNonce, sNonce, cHeader, serverFirstMessage)
	fmt.Println("--------------------------")
	fmt.Println("salt:", sSalt)
	saltedPassword := pbkdf2Sum(normalize(cPass), []byte(sSalt), iterations)
	// ServerKey is the HMAC itself. Only StoredKey (client side) is hashed.
	// SaltedPassword and ServerKey are secrets and are no longer printed.
	serverKey := hmacSum(saltedPassword, []byte("Server Key"))
	fmt.Printf("serverSignature hex:%x\n", serverSignature)
	expected := hmacSum(serverKey, authMsg)
	// Constant-time compare, so timing does not reveal how many leading
	// bytes of a forged signature were correct.
	return hmac.Equal(expected, serverSignature)
}
